/* -*- c -*- */

/*
 * The purpose of this module is to add faster sort functions
 * that are type-specific.  This is done by altering the
 * function table for the builtin descriptors.
 *
 * These sorting functions are copied almost directly from numarray
 * with a few modifications (complex comparisons compare the imaginary
 * part if the real parts are equal, for example), and the names
 * are changed.
 *
 * The original sorting code is due to Charles R. Harris who wrote
 * it for numarray.
 */

/*
 * Quick sort is usually the fastest, but the worst case scenario can
 * be slower than the merge and heap sorts.  The merge sort requires
 * extra memory and so for large arrays may not be useful.
 *
 * The merge sort is *stable*, meaning that elements that compare equal
 * keep their original relative order, so it can be used to implement
 * lexicographic sorting on multiple keys.
 *
 * The heap sort is included for completeness.
 */

#define NPY_NO_DEPRECATED_API NPY_API_VERSION

#include "npy_sort.h"
#include "npysort_common.h"
#include <stdlib.h>

/* Parameter name used for arguments a function accepts but never reads */
#define NOT_USED NPY_UNUSED(unused)
/* The next two constants are not referenced by the sort routines visible here */
#define PYA_QS_STACK 100
#define SMALL_QUICKSORT 15
/* Ranges of at most this many elements are insertion sorted instead of merged */
#define SMALL_MERGESORT 20
/* Not referenced by the sort routines visible here */
#define SMALL_STRING 16


/*
 *****************************************************************************
 **                            NUMERIC SORTS                                **
 *****************************************************************************
 */


/**begin repeat
 *
 * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
 *         LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
 *         CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA#
 * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
 *         longlong, ulonglong, half, float, double, longdouble,
 *         cfloat, cdouble, clongdouble, datetime, timedelta#
 * #type = npy_bool, npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int,
 *         npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong,
 *         npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat,
 *         npy_cdouble, npy_clongdouble, npy_datetime, npy_timedelta#
 */

/*
 * Recursive worker: sort the half-open range [pl, pr) in place.
 *
 * Ranges of at most SMALL_MERGESORT elements are insertion sorted.  Longer
 * ranges are split at the midpoint, each half is sorted recursively, and
 * the halves are merged.  The merge first copies the left half, which is
 * (pr - pl) >> 1 elements, into the workspace pw.  When elements from both
 * halves compare equal, the one from the left half is emitted first.
 */
static void
mergesort0_@suff@(@type@ *pl, @type@ *pr, @type@ *pw)
{
    @type@ key;
    @type@ *mid, *cur, *hole, *prev, *left, *left_end, *out;

    if (pr - pl <= SMALL_MERGESORT) {
        /* short range: insertion sort */
        for (cur = pl + 1; cur < pr; ++cur) {
            key = *cur;
            hole = cur;
            prev = cur - 1;
            /* shift larger elements up until the slot for key is found */
            while (hole > pl && @TYPE@_LT(key, *prev)) {
                *hole = *prev;
                --hole;
                --prev;
            }
            *hole = key;
        }
        return;
    }

    /* long range: sort both halves, then merge them */
    mid = pl + ((pr - pl) >> 1);
    mergesort0_@suff@(pl, mid, pw);
    mergesort0_@suff@(mid, pr, pw);

    /* save the sorted left half in the workspace */
    left_end = pw;
    for (cur = pl; cur < mid; ++cur) {
        *left_end++ = *cur;
    }

    left = pw;
    out = pl;
    /* from here on mid walks through the right half */
    while (left < left_end && mid < pr) {
        if (@TYPE@_LT(*mid, *left)) {
            *out++ = *mid++;
        }
        else {
            *out++ = *left++;
        }
    }
    /* leftovers of the right half are already in their final place */
    while (left < left_end) {
        *out++ = *left++;
    }
}


/*
 * Sort num elements of type @type@ beginning at start, in place, with
 * merge sort.
 *
 * A workspace of num/2 elements is allocated for the merge step and
 * released before returning.  The third argument is unused.
 *
 * Returns 0 on success and -NPY_ENOMEM if the workspace cannot be
 * allocated.
 */
NPY_NO_EXPORT int
mergesort_@suff@(void *start, npy_intp num, void *NOT_USED)
{
    @type@ *pl, *pr, *pw;

    /*
     * Fewer than two elements are already sorted.  Returning here also
     * avoids a zero-byte malloc, which is allowed to return NULL and would
     * then be misreported as an out-of-memory condition.
     */
    if (num < 2) {
        return 0;
    }

    pl = start;
    pr = pl + num;
    pw = malloc((num/2) * sizeof(@type@));
    if (pw == NULL) {
        return -NPY_ENOMEM;
    }
    mergesort0_@suff@(pl, pr, pw);

    free(pw);
    return 0;
}


/*
 * Recursive worker for the indirect sort: reorder the indices stored in
 * [pl, pr) so that the values of v they refer to are in ascending order.
 * The array v itself is only read.
 *
 * Ranges of at most SMALL_MERGESORT indices are insertion sorted.  Longer
 * ranges are split at the midpoint, sorted recursively and merged, with the
 * left half of the indices, (pr - pl) >> 1 of them, copied into the
 * workspace pw first.  On equal values the index from the left half is
 * emitted first.
 */
static void
amergesort0_@suff@(npy_intp *pl, npy_intp *pr, @type@ *v, npy_intp *pw)
{
    @type@ key;
    npy_intp idx;
    npy_intp *mid, *cur, *hole, *prev, *left, *left_end, *out;

    if (pr - pl <= SMALL_MERGESORT) {
        /* short range: insertion sort on the indices */
        for (cur = pl + 1; cur < pr; ++cur) {
            idx = *cur;
            key = v[idx];
            hole = cur;
            prev = cur - 1;
            while (hole > pl && @TYPE@_LT(key, v[*prev])) {
                *hole = *prev;
                --hole;
                --prev;
            }
            *hole = idx;
        }
        return;
    }

    /* long range: sort both halves, then merge them */
    mid = pl + ((pr - pl) >> 1);
    amergesort0_@suff@(pl, mid, v, pw);
    amergesort0_@suff@(mid, pr, v, pw);

    /* save the sorted left half of the indices in the workspace */
    left_end = pw;
    for (cur = pl; cur < mid; ++cur) {
        *left_end++ = *cur;
    }

    left = pw;
    out = pl;
    /* from here on mid walks through the right half */
    while (left < left_end && mid < pr) {
        if (@TYPE@_LT(v[*mid], v[*left])) {
            *out++ = *mid++;
        }
        else {
            *out++ = *left++;
        }
    }
    /* leftovers of the right half are already in their final place */
    while (left < left_end) {
        *out++ = *left++;
    }
}


/*
 * Indirect merge sort: permute the num indices already stored in tosort so
 * that the elements of v (of type @type@) they select are in ascending
 * order.  v is not modified.
 *
 * A workspace of num/2 indices is allocated for the merge step and released
 * before returning.  The last argument is unused.
 *
 * Returns 0 on success and -NPY_ENOMEM if the workspace cannot be
 * allocated.
 */
NPY_NO_EXPORT int
amergesort_@suff@(void *v, npy_intp *tosort, npy_intp num, void *NOT_USED)
{
    npy_intp *pl, *pr, *pw;

    /*
     * Fewer than two indices are already in order.  Returning here also
     * avoids a zero-byte malloc, which is allowed to return NULL and would
     * then be misreported as an out-of-memory condition.
     */
    if (num < 2) {
        return 0;
    }

    pl = tosort;
    pr = pl + num;
    pw = malloc((num/2) * sizeof(npy_intp));
    if (pw == NULL) {
        return -NPY_ENOMEM;
    }
    amergesort0_@suff@(pl, pr, v, pw);
    free(pw);

    return 0;
}

/**end repeat**/


/*
 *****************************************************************************
 **                             STRING SORTS                                **
 *****************************************************************************
 */


/**begin repeat
 *
 * #TYPE = STRING, UNICODE#
 * #suff = string, unicode#
 * #type = npy_char, npy_ucs4#
 */

/*
 * Recursive worker for fixed-width items: sort the range [pl, pr) in place,
 * where every item occupies len consecutive units of @type@.
 *
 * Ranges of at most SMALL_MERGESORT items are insertion sorted, using vp as
 * a buffer for the single item being inserted.  Longer ranges are split on
 * an item boundary, sorted recursively and merged, with the left half
 * copied into the workspace pw first.  On equal items the one from the left
 * half is emitted first.
 */
static void
mergesort0_@suff@(@type@ *pl, @type@ *pr, @type@ *pw, @type@ *vp, size_t len)
{
    @type@ *mid, *cur, *hole, *prev, *left, *left_end, *out;

    if ((size_t)(pr - pl) <= SMALL_MERGESORT*len) {
        /* short range: insertion sort */
        for (cur = pl + len; cur < pr; cur += len) {
            @TYPE@_COPY(vp, cur, len);
            hole = cur;
            prev = cur - len;
            /* shift larger items up until the slot for vp is found */
            while (hole > pl && @TYPE@_LT(vp, prev, len)) {
                @TYPE@_COPY(hole, prev, len);
                hole -= len;
                prev -= len;
            }
            @TYPE@_COPY(hole, vp, len);
        }
        return;
    }

    /* long range: split on an item boundary, sort both halves, merge */
    mid = pl + (((pr - pl)/len) >> 1)*len;
    mergesort0_@suff@(pl, mid, pw, vp, len);
    mergesort0_@suff@(mid, pr, pw, vp, len);

    /* save the sorted left half in the workspace */
    @TYPE@_COPY(pw, pl, mid - pl);
    left = pw;
    left_end = pw + (mid - pl);
    out = pl;
    /* from here on mid walks through the right half */
    while (left < left_end && mid < pr) {
        if (@TYPE@_LT(mid, left, len)) {
            @TYPE@_COPY(out, mid, len);
            mid += len;
        }
        else {
            @TYPE@_COPY(out, left, len);
            left += len;
        }
        out += len;
    }
    /* leftovers of the right half are already in their final place */
    @TYPE@_COPY(out, left, left_end - left);
}


/*
 * Sort num fixed-width items beginning at start, in place, with merge sort.
 * The item size in bytes is taken from the array object passed as varr.
 *
 * Two temporary buffers are allocated and released before returning: a
 * merge workspace of num/2 items and a one-item buffer for the insertion
 * sort.
 *
 * Returns 0 on success and -NPY_ENOMEM if either buffer cannot be
 * allocated.
 */
NPY_NO_EXPORT int
mergesort_@suff@(void *start, npy_intp num, void *varr)
{
    PyArrayObject *arr = varr;
    size_t elsize = PyArray_ITEMSIZE(arr);
    size_t len = elsize / sizeof(@type@);
    @type@ *pl, *pr, *pw, *vp;
    int err = 0;

    /* Items that have zero size don't make sense to sort */
    if (elsize == 0) {
        return 0;
    }

    /*
     * Fewer than two items are already sorted.  Returning here also avoids
     * a zero-byte malloc for the workspace, which is allowed to return NULL
     * and would then be misreported as an out-of-memory condition.
     */
    if (num < 2) {
        return 0;
    }

    pl = start;
    pr = pl + num*len;
    pw = malloc((num/2) * elsize);
    if (pw == NULL) {
        err = -NPY_ENOMEM;
        goto fail_0;
    }
    vp = malloc(elsize);
    if (vp == NULL) {
        err = -NPY_ENOMEM;
        goto fail_1;
    }
    mergesort0_@suff@(pl, pr, pw, vp, len);

    free(vp);
fail_1:
    free(pw);
fail_0:
    return err;
}


/*
 * Recursive worker for the indirect sort of fixed-width items: reorder the
 * indices stored in [pl, pr) so that the items of v they refer to are in
 * ascending order.  Item i starts at v + i*len and v is only read.
 *
 * Ranges of at most SMALL_MERGESORT indices are insertion sorted.  Longer
 * ranges are split at the midpoint, sorted recursively and merged, with the
 * left half of the indices copied into the workspace pw first.  On equal
 * items the index from the left half is emitted first.
 */
static void
amergesort0_@suff@(npy_intp *pl, npy_intp *pr, @type@ *v, npy_intp *pw, size_t len)
{
    @type@ *key;
    npy_intp idx;
    npy_intp *mid, *cur, *hole, *prev, *left, *left_end, *out;

    if (pr - pl <= SMALL_MERGESORT) {
        /* short range: insertion sort on the indices */
        for (cur = pl + 1; cur < pr; ++cur) {
            idx = *cur;
            key = v + idx*len;
            hole = cur;
            prev = cur - 1;
            while (hole > pl && @TYPE@_LT(key, v + (*prev)*len, len)) {
                *hole = *prev;
                --hole;
                --prev;
            }
            *hole = idx;
        }
        return;
    }

    /* long range: sort both halves, then merge them */
    mid = pl + ((pr - pl) >> 1);
    amergesort0_@suff@(pl, mid, v, pw, len);
    amergesort0_@suff@(mid, pr, v, pw, len);

    /* save the sorted left half of the indices in the workspace */
    left_end = pw;
    for (cur = pl; cur < mid; ++cur) {
        *left_end++ = *cur;
    }

    left = pw;
    out = pl;
    /* from here on mid walks through the right half */
    while (left < left_end && mid < pr) {
        if (@TYPE@_LT(v + (*mid)*len, v + (*left)*len, len)) {
            *out++ = *mid++;
        }
        else {
            *out++ = *left++;
        }
    }
    /* leftovers of the right half are already in their final place */
    while (left < left_end) {
        *out++ = *left++;
    }
}


/*
 * Indirect merge sort for fixed-width items: permute the num indices
 * already stored in tosort so that the items of v they select are in
 * ascending order.  The item size in bytes is taken from the array object
 * passed as varr; v is not modified.
 *
 * A workspace of num/2 indices is allocated for the merge step and released
 * before returning.
 *
 * Returns 0 on success and -NPY_ENOMEM if the workspace cannot be
 * allocated.
 */
NPY_NO_EXPORT int
amergesort_@suff@(void *v, npy_intp *tosort, npy_intp num, void *varr)
{
    PyArrayObject *arr = varr;
    size_t elsize = PyArray_ITEMSIZE(arr);
    size_t len = elsize / sizeof(@type@);
    npy_intp *pl, *pr, *pw;

    /* Items that have zero size don't make sense to sort */
    if (elsize == 0) {
        return 0;
    }

    /*
     * Fewer than two indices are already in order.  Returning here also
     * avoids a zero-byte malloc, which is allowed to return NULL and would
     * then be misreported as an out-of-memory condition.
     */
    if (num < 2) {
        return 0;
    }

    pl = tosort;
    pr = pl + num;
    pw = malloc((num/2) * sizeof(npy_intp));
    if (pw == NULL) {
        return -NPY_ENOMEM;
    }
    amergesort0_@suff@(pl, pr, v, pw, len);
    free(pw);

    return 0;
}

/**end repeat**/


/*
 *****************************************************************************
 **                             GENERIC SORT                                **
 *****************************************************************************
 */


/*
 * Recursive worker for the generic sort: sort the byte range [pl, pr) in
 * place, where every item occupies elsize bytes and items are ordered by
 * cmp(a, b, arr), with a negative result meaning that a sorts before b.
 *
 * Ranges of at most SMALL_MERGESORT items are insertion sorted, using vp as
 * a buffer for the single item being inserted.  Longer ranges are split on
 * an item boundary, sorted recursively and merged, with the left half
 * copied into the workspace pw first.  On equal items the one from the left
 * half is emitted first.
 */
static void
npy_mergesort0(char *pl, char *pr, char *pw, char *vp, npy_intp elsize,
               PyArray_CompareFunc *cmp, PyArrayObject *arr)
{
    char *mid, *cur, *hole, *prev, *left, *left_end, *out;

    if (pr - pl <= SMALL_MERGESORT*elsize) {
        /* short range: insertion sort */
        for (cur = pl + elsize; cur < pr; cur += elsize) {
            GENERIC_COPY(vp, cur, elsize);
            hole = cur;
            prev = cur - elsize;
            /* shift larger items up until the slot for vp is found */
            while (hole > pl && cmp(vp, prev, arr) < 0) {
                GENERIC_COPY(hole, prev, elsize);
                hole -= elsize;
                prev -= elsize;
            }
            GENERIC_COPY(hole, vp, elsize);
        }
        return;
    }

    /* long range: split on an item boundary, sort both halves, merge */
    mid = pl + (((pr - pl)/elsize) >> 1)*elsize;
    npy_mergesort0(pl, mid, pw, vp, elsize, cmp, arr);
    npy_mergesort0(mid, pr, pw, vp, elsize, cmp, arr);

    /* save the sorted left half in the workspace */
    GENERIC_COPY(pw, pl, mid - pl);
    left = pw;
    left_end = pw + (mid - pl);
    out = pl;
    /* from here on mid walks through the right half */
    while (left < left_end && mid < pr) {
        if (cmp(mid, left, arr) < 0) {
            GENERIC_COPY(out, mid, elsize);
            mid += elsize;
        }
        else {
            GENERIC_COPY(out, left, elsize);
            left += elsize;
        }
        out += elsize;
    }
    /* leftovers of the right half are already in their final place */
    GENERIC_COPY(out, left, left_end - left);
}


/*
 * Generic merge sort: sort num items beginning at start, in place.  The
 * item size and the comparison function are both taken from the array
 * object passed as varr.
 *
 * Two temporary buffers are allocated and released before returning: a
 * merge workspace of num/2 items and a one-item buffer for the insertion
 * sort.
 *
 * Returns 0 on success and -NPY_ENOMEM if either buffer cannot be
 * allocated.
 */
NPY_NO_EXPORT int
npy_mergesort(void *start, npy_intp num, void *varr)
{
    PyArrayObject *arr = varr;
    npy_intp elsize = PyArray_ITEMSIZE(arr);
    PyArray_CompareFunc *cmp = PyArray_DESCR(arr)->f->compare;
    char *pl = start;
    char *pr = pl + num*elsize;
    char *pw;
    char *vp;
    int err = -NPY_ENOMEM;

    /* Items that have zero size don't make sense to sort */
    if (elsize == 0) {
        return 0;
    }

    /*
     * Fewer than two items are already sorted.  Returning here also avoids
     * a zero-byte malloc for the workspace, which is allowed to return NULL
     * and would then be misreported as an out-of-memory condition.
     */
    if (num < 2) {
        return 0;
    }

    pw = malloc((num >> 1) *elsize);
    vp = malloc(elsize);

    /* sort only if both buffers exist; free(NULL) below is harmless */
    if (pw != NULL && vp != NULL) {
        npy_mergesort0(pl, pr, pw, vp, elsize, cmp, arr);
        err = 0;
    }

    free(vp);
    free(pw);

    return err;
}


/*
 * Recursive worker for the generic indirect sort: reorder the indices
 * stored in [pl, pr) so that the items of v they refer to are in ascending
 * order.  Item i starts at v + i*elsize, items are ordered by
 * cmp(a, b, arr) with a negative result meaning that a sorts before b, and
 * v is only read.
 *
 * Ranges of at most SMALL_MERGESORT indices are insertion sorted.  Longer
 * ranges are split at the midpoint, sorted recursively and merged, with the
 * left half of the indices copied into the workspace pw first.  On equal
 * items the index from the left half is emitted first.
 */
static void
npy_amergesort0(npy_intp *pl, npy_intp *pr, char *v, npy_intp *pw,
                npy_intp elsize, PyArray_CompareFunc *cmp, PyArrayObject *arr)
{
    char *key;
    npy_intp idx;
    npy_intp *mid, *cur, *hole, *prev, *left, *left_end, *out;

    if (pr - pl <= SMALL_MERGESORT) {
        /* short range: insertion sort on the indices */
        for (cur = pl + 1; cur < pr; ++cur) {
            idx = *cur;
            key = v + idx*elsize;
            hole = cur;
            prev = cur - 1;
            while (hole > pl && cmp(key, v + (*prev)*elsize, arr) < 0) {
                *hole = *prev;
                --hole;
                --prev;
            }
            *hole = idx;
        }
        return;
    }

    /* long range: sort both halves, then merge them */
    mid = pl + ((pr - pl) >> 1);
    npy_amergesort0(pl, mid, v, pw, elsize, cmp, arr);
    npy_amergesort0(mid, pr, v, pw, elsize, cmp, arr);

    /* save the sorted left half of the indices in the workspace */
    left_end = pw;
    for (cur = pl; cur < mid; ++cur) {
        *left_end++ = *cur;
    }

    left = pw;
    out = pl;
    /* from here on mid walks through the right half */
    while (left < left_end && mid < pr) {
        if (cmp(v + (*mid)*elsize, v + (*left)*elsize, arr) < 0) {
            *out++ = *mid++;
        }
        else {
            *out++ = *left++;
        }
    }
    /* leftovers of the right half are already in their final place */
    while (left < left_end) {
        *out++ = *left++;
    }
}


/*
 * Generic indirect merge sort: permute the num indices already stored in
 * tosort so that the items of v they select are in ascending order.  The
 * item size and the comparison function are both taken from the array
 * object passed as varr; v is not modified.
 *
 * A workspace of num/2 indices is allocated for the merge step and released
 * before returning.
 *
 * Returns 0 on success and -NPY_ENOMEM if the workspace cannot be
 * allocated.
 */
NPY_NO_EXPORT int
npy_amergesort(void *v, npy_intp *tosort, npy_intp num, void *varr)
{
    PyArrayObject *arr = varr;
    npy_intp elsize = PyArray_ITEMSIZE(arr);
    PyArray_CompareFunc *cmp = PyArray_DESCR(arr)->f->compare;
    npy_intp *pl, *pr, *pw;

    /* Items that have zero size don't make sense to sort */
    if (elsize == 0) {
        return 0;
    }

    /*
     * Fewer than two indices are already in order.  Returning here also
     * avoids a zero-byte malloc, which is allowed to return NULL and would
     * then be misreported as an out-of-memory condition.
     */
    if (num < 2) {
        return 0;
    }

    pl = tosort;
    pr = pl + num;
    pw = malloc((num >> 1) * sizeof(npy_intp));
    if (pw == NULL) {
        return -NPY_ENOMEM;
    }
    npy_amergesort0(pl, pr, v, pw, elsize, cmp, arr);
    free(pw);

    return 0;
}
